# src/vol5_k2m_cc/run_geometry_cc.py
from __future__ import annotations

import argparse
import hashlib
import itertools
import json
import os
import re
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
import yaml

# Local modules (these must exist in your repo)
from .cc_translator import CCTConfig, build_mask
from .potential import PoissonCfg, compute_phi_from_mask, radial_dist
from .metrics import FitCfg, fit_slopes


# ---------------- utilities ----------------

def _sha256_file(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

def _sha256_array(arr: np.ndarray) -> str:
    h = hashlib.sha256()
    h.update(np.ascontiguousarray(arr).tobytes(order="C"))
    return h.hexdigest()

def _ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)

def _load_e0(e0_path: str) -> np.ndarray:
    """Load E0 from .npz (preferred) or .npy. Accept key 'E0' or first array."""
    if not os.path.exists(e0_path):
        raise FileNotFoundError(e0_path)
    if e0_path.endswith(".npz"):
        z = np.load(e0_path, allow_pickle=False)
        key = "E0" if "E0" in z.files else z.files[0]
        a = z[key]
        z.close()
        return a
    elif e0_path.endswith(".npy"):
        return np.load(e0_path, allow_pickle=False)
    else:
        raise ValueError(f"Unsupported E0 file type: {e0_path}")

def _pick_kernel(data_dir: str, gauge: str, L: int) -> str | None:
    """
    Try to locate a kernel under <data_dir>/kernels/<gauge>/ matching L.
    Accepts names like:
      SU2: kernel_L128.npy
      SU3: kernel_SU3_L128.npy
    """
    root = os.path.join(data_dir, "kernels", gauge)
    if not os.path.isdir(root):
        return None
    cand = []
    for fn in os.listdir(root):
        if fn.endswith(".npy") and f"L{L}" in fn:
            cand.append(os.path.join(root, fn))
    if not cand:
        return None
    cand.sort(key=lambda p: len(os.path.basename(p)))  # prefer shorter names
    return cand[0]

def _split_poisson_cfg(poisson_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split YAML 'poisson:' into PoissonCfg kwargs vs FitCfg kwargs."""
    phi_keys = {"epsilon_soften", "background_subtract", "ring_inner_fracL", "ring_outer_fracL"}
    fit_keys = {"fit_window_min", "fit_window_max_fracL", "radial_bins_scheme", "radial_bins", "regression_weights"}
    phi_kwargs = {k: v for k, v in poisson_dict.items() if k in phi_keys}
    fit_kwargs = {k: v for k, v in poisson_dict.items() if k in fit_keys}
    return phi_kwargs, fit_kwargs

def _as_list(x):
    return x if isinstance(x, list) else [x]

# ---------------- simple lensing from φ ----------------
def _alpha_vs_b_from_phi(phi: np.ndarray, b_min: int, b_max: int, n_b: int,
                         center: Tuple[float, float]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute a simple deflection curve α(b) from φ by averaging the radial
    component of ∇φ on thin rings at radii b. Deterministic, no NaNs.
    Returns (b_values, alpha_values).
    """
    # gradient
    gy, gx = np.gradient(phi)             # central differences
    L = phi.shape[0]
    y, x = np.indices((L, L))
    cy, cx = center
    dy = y - cy
    dx = x - cx
    r = np.hypot(dx, dy) + 1e-9  # avoid div-by-zero

    # radial unit vector and radial component of grad
    urx = dx / r
    ury = dy / r
    grad_rad = gx * urx + gy * ury

    b_vals = np.linspace(float(b_min), float(b_max), int(n_b))
    alphas = []
    for b in b_vals:
        ring = (r >= (b - 0.5)) & (r < (b + 0.5))
        if not np.any(ring):
            alphas.append(0.0)
        else:
            # Mean absolute radial gradient on the ring
            alphas.append(float(np.mean(np.abs(grad_rad[ring]))))
    return b_vals, np.array(alphas, dtype=float)

def _linfit_y_x(y: np.ndarray, x: np.ndarray) -> Tuple[float, float, float]:
    """Least-squares fit y = m*x + c; returns (m, c, R^2)."""
    x = np.asarray(x, float).reshape(-1, 1)
    y = np.asarray(y, float).reshape(-1, 1)
    X = np.hstack([x, np.ones_like(x)])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    m = float(beta[0, 0]); c = float(beta[1, 0])
    yhat = X @ beta
    ss_res = float(np.sum((y - yhat) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2)) + 1e-12
    r2 = 1.0 - ss_res / ss_tot
    return m, c, r2

# ---------------- core runner ----------------
def _unpack_mask_output(mask_out: Any) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Normalize ``build_mask``'s return value to ``(bool mask, info dict)``.

    Accepts three forms:
      - a bare array;
      - a ``(mask, info, ...)`` tuple;
      - a dict holding the array under ``"mask"``, in which case the dict
        itself is used as info.
    A tuple whose second item is not a dict yields an empty info dict.
    """
    if isinstance(mask_out, tuple):
        mask = mask_out[0]
        info = mask_out[1] if len(mask_out) > 1 and isinstance(mask_out[1], dict) else {}
    elif isinstance(mask_out, dict):
        mask = mask_out["mask"]
        info = mask_out
    else:
        mask = mask_out
        info = {}
    return mask.astype(bool), info


def _kernel_metadata(data_dir: str, gauge: str, L: int
                     ) -> Tuple[str | None, List[int] | None, str | None, str | None]:
    """Best-effort provenance for the kernel matching (gauge, L).

    Returns (path, shape, dtype, sha256). Everything except the path is None
    when no kernel is found or the file cannot be read as a .npy array.
    """
    kernel_path = _pick_kernel(data_dir, gauge, L)
    if not (kernel_path and os.path.exists(kernel_path)):
        return kernel_path, None, None, None
    try:
        # mmap_mode="r" maps the file lazily instead of reading the whole
        # kernel into memory; only shape/dtype are needed here.
        karr = np.load(kernel_path, mmap_mode="r", allow_pickle=False)
        shape, dtype = list(karr.shape), str(karr.dtype)
        del karr  # release the memory map before hashing the file
        return kernel_path, shape, dtype, _sha256_file(kernel_path)
    except Exception:
        # Kernel metadata is optional provenance; an unreadable kernel must
        # not fail the run.
        return kernel_path, None, None, None


def _lambda_sweep(alpha_vals: np.ndarray, x: np.ndarray, lam_list: List[Any]
                  ) -> Tuple[Dict[str, float], str | None, float]:
    """Fit lam * alpha(b) against x for every lam in lam_list.

    Returns three things:
      - a dict of per-lambda slope and R^2 fields;
      - the label of the lambda with the highest R^2, or None;
      - that R^2, or NaN when no fit beat -inf (e.g. an empty lam_list).
    """
    fields: Dict[str, float] = {}
    best_lambda = None
    best_r2 = -np.inf
    for lam in lam_list:
        m, _c, r2 = _linfit_y_x(lam * alpha_vals, x)
        fields[f"alpha_slope_lambda{lam}"] = float(m)
        fields[f"alpha_r2_lambda{lam}"] = float(r2)
        # NOTE(review): R^2 is invariant to scaling y by lam (apart from the
        # 1e-12 epsilon inside _linfit_y_x), so this pick is likely arbitrary.
        # Confirm whether a different selection criterion was intended.
        if r2 > best_r2:
            best_r2 = r2
            best_lambda = f"lambda_{lam}"
    if best_r2 == -np.inf:
        best_r2 = float("nan")
    return fields, best_lambda, float(best_r2)


def _json_default(obj: Any) -> Any:
    """``json.dumps`` hook that converts NumPy scalars/arrays to Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def run_one(data_dir: str, out_root: str, meta: Dict[str, Any], cfg: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, str | None]:
    """
    Run 2b compact-curvature on a single anchor (meta).

    Steps:
      1. Load E0 from <data_dir>/<gauge>/L<L>/b<b>/k<kappa>/f<f>/seed<seed>/E0.npz.
      2. Build the S+ mask.
      3. Compute phi and the Poisson slope fits.
      4. Run the lensing alpha(b) lambda sweep.
      5. Collect optional kernel provenance.
      6. Write result.json under the same sub-path of out_root.

    Args:
        data_dir: root holding the E0 files and, optionally, kernels/<gauge>/.
        out_root: root under which per-run result folders are created.
        meta: mapping with keys gauge, L, b, kappa, f, seed.
        cfg: parsed config; reads "translator", "poisson" and optional "optics".

    Returns:
        (result_dict, None) on success. (None, err_msg) if E0 cannot be
        loaded or does not have shape (L, L). Other failures raise.
    """
    gauge = str(meta["gauge"])
    L = int(meta["L"])
    b = float(meta["b"])
    kappa = float(meta["kappa"])
    f = float(meta["f"])
    seed = int(meta["seed"])
    # `optics:` may be absent or present-but-null in the YAML.
    optics = cfg.get("optics") or {}

    # Input and output share the same per-run sub-path layout.
    run_subpath = os.path.join(
        gauge, f"L{L}", f"b{b}", f"k{kappa:.2f}", f"f{f:.2f}", f"seed{seed}"
    )
    e0_path = os.path.join(data_dir, run_subpath, "E0.npz")
    try:
        e0 = _load_e0(e0_path)
    except Exception as e:
        return None, f"E0 load failed: {e}"
    if not (isinstance(e0, np.ndarray) and e0.shape == (L, L)):
        return None, f"E0 missing or wrong shape (got {None if not isinstance(e0,np.ndarray) else e0.shape})"

    # Translator -> S+ mask
    translator_cfg = cfg["translator"]
    tcfg = CCTConfig(**translator_cfg)
    mask, info = _unpack_mask_output(build_mask(e0, tcfg))
    splus_cov_pct = 100.0 * float(mask.mean())
    sigma_list = getattr(tcfg, "sigma_list", None) or translator_cfg.get("sigma_list", [])
    splus_sigma_max = float(info.get("sigma_max", max(sigma_list) if sigma_list else 0.0))
    splus_hash = _sha256_array(mask)

    # Poisson potential + slope fits
    phi_kwargs, fit_kwargs = _split_poisson_cfg(cfg["poisson"])
    phi_cfg = PoissonCfg(**phi_kwargs)
    fit_cfg = FitCfg(**fit_kwargs)
    phi, grad_mag = compute_phi_from_mask(mask.astype(np.float64), splus_sigma_max, phi_cfg)
    fits = fit_slopes(phi, grad_mag, splus_sigma_max, fit_cfg)

    # Lensing from phi, centred on the S+ centroid (image centre if S+ is empty).
    yy, xx = np.nonzero(mask)
    if yy.size:
        cy, cx = float(np.mean(yy)), float(np.mean(xx))
    else:
        cy = cx = (L - 1) / 2.0
    b_vals, alpha_vals = _alpha_vs_b_from_phi(
        phi,
        int(optics.get("lensing_b_min", 12)),
        int(optics.get("lensing_b_max", 64)),
        int(optics.get("lensing_b_n", 32)),
        center=(cy, cx),
    )
    x = 1.0 / np.maximum(b_vals, 1e-6)  # guard 1/b against b <= 0
    alpha_fields, best_lambda, best_r2 = _lambda_sweep(
        alpha_vals, x, optics.get("lambda_sweep", [0.2, 0.5, 1.0])
    )

    kernel_path, kernel_shape, kernel_dtype, kernel_hash = _kernel_metadata(data_dir, gauge, L)

    res = {
        "gauge": gauge, "L": L, "b": b, "kappa": kappa, "f": f, "seed": seed,
        "e0_path": e0_path,
        "e0_hash": _sha256_file(e0_path),
        "kernel_path": kernel_path,
        "kernel_shape": kernel_shape,
        "kernel_dtype": kernel_dtype,
        "kernel_hash": kernel_hash,
        "splus_cov_pct": splus_cov_pct,
        "splus_sigma_max": splus_sigma_max,
        "splus_hash": splus_hash,
        **fits,
        **alpha_fields,
        "alpha_best_lambda": best_lambda,
        "alpha_best_r2": best_r2,
    }

    # Per-run subfolder so multiple runs don't overwrite each other.
    run_dir = os.path.join(out_root, run_subpath)
    _ensure_dir(run_dir)
    # Serialize fully before opening the file, so an unserializable value
    # cannot leave a truncated result.json behind.
    payload = json.dumps(res, indent=2, default=_json_default)
    with open(os.path.join(run_dir, "result.json"), "w", encoding="utf-8") as fh:
        fh.write(payload)

    return res, None


def main():
    """Command-line entry point: run ``run_one`` over every anchor grid in a YAML config.

    Each anchor supplies ``gauge`` plus scalar-or-list values for ``L``,
    ``b``, ``kappa``, ``f`` and ``seeds`` (falling back to ``seed``). The
    full Cartesian product is run in nested order L, b, kappa, f, seed.
    Each combination is reported as OK or FAIL and counted, so one bad
    value does not abort the batch. The process exits with status 2 when
    the config lists no anchors.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    ap.add_argument("--data-dir", required=True)
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh) or {}  # an empty YAML file loads as None

    anchors = cfg.get("anchors", [])
    if not anchors:
        print("[ERR] no anchors in config", file=sys.stderr)
        sys.exit(2)

    ok = 0
    fail = 0

    for anchor in anchors:
        if "gauge" not in anchor:
            print(f"[FAIL] anchor without 'gauge': {anchor}")
            fail += 1
            continue
        gauge = str(anchor["gauge"])
        # product() yields combinations in the same order as nested for-loops.
        grid = itertools.product(
            _as_list(anchor.get("L", [])),
            _as_list(anchor.get("b", [])),
            _as_list(anchor.get("kappa", [])),
            _as_list(anchor.get("f", [])),
            _as_list(anchor.get("seeds", anchor.get("seed", []))),
        )
        for L, b, kappa, f_val, seed in grid:
            # Raw values are reported if the numeric conversion itself fails.
            meta: Dict[str, Any] = {
                "gauge": gauge, "L": L, "b": b, "kappa": kappa, "f": f_val, "seed": seed,
            }
            try:
                meta = {
                    "gauge": gauge, "L": int(L), "b": float(b),
                    "kappa": float(kappa), "f": float(f_val), "seed": int(seed),
                }
                res, err = run_one(args.data_dir, args.out, meta, cfg)
            except Exception as e:
                print(f"[FAIL] {meta}\n {e}")
                fail += 1
                continue
            if err:
                print(f"[FAIL] {meta}\n {err}")
                fail += 1
            else:
                print(f"[OK  ] {res['e0_path']}")
                ok += 1

    print(f"[done] ok={ok} fail/skip={fail} out_dir={args.out}")


# Script entry point. This module uses relative imports, so run it as a module
# (e.g. `python -m vol5_k2m_cc.run_geometry_cc ...`; the package path is assumed
# from the header comment, so confirm it) rather than as a plain script file.
if __name__ == "__main__":
    main()
